{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# 注意力评分函数\n",
    ":label:`sec_attention-scoring-functions`\n",
    "\n",
    "在 :numref:`sec_nadaraya-waston` 中，我们使用高斯核来对查询和键之间的关系建模。可以将 :eqref:`eq_nadaraya-waston-gaussian` 中的高斯核的指数部分视为 *注意力评分函数（attention scoring function）*，简称 *评分函数（scoring function）*，然后把这个函数的输出结果输入到 softmax 函数中进行运算。通过上述步骤，我们将得到与键对应的值的概率分布（即注意力权重）。最后，注意力汇聚的输出就是基于这些注意力权重的值的加权和。\n",
    "\n",
    "从宏观来看，可以使用上述算法来实现 :numref:`fig_qkv` 中的注意力机制框架。:numref:`fig_attention_output` 说明了如何将注意力汇聚的输出计算成为值的加权和，其中  𝑎  表示注意力评分函数。由于注意力权重是概率分布，因此加权和其本质上是加权平均值。\n",
    "\n",
    "![计算注意力汇聚的输出为值的加权和。](https://zh-v2.d2l.ai/_images/attention-output.svg)\n",
    ":label:`fig_attention_output`\n",
    "\n",
    "\n",
    "\n",
    "Mathematically,\n",
    "suppose that we have\n",
    "a query $\\mathbf{q} \\in \\mathbb{R}^q$\n",
    "and $m$ key-value pairs $(\\mathbf{k}_1, \\mathbf{v}_1), \\ldots, (\\mathbf{k}_m, \\mathbf{v}_m)$, where any $\\mathbf{k}_i \\in \\mathbb{R}^k$ and any $\\mathbf{v}_i \\in \\mathbb{R}^v$.\n",
    "The attention pooling $f$\n",
    "is instantiated as a weighted sum of the values:\n",
    "\n",
    "$$f(\\mathbf{q}, (\\mathbf{k}_1, \\mathbf{v}_1), \\ldots, (\\mathbf{k}_m, \\mathbf{v}_m)) = \\sum_{i=1}^m \\alpha(\\mathbf{q}, \\mathbf{k}_i) \\mathbf{v}_i \\in \\mathbb{R}^v,$$\n",
    ":eqlabel:`eq_attn-pooling`\n",
    "\n",
    "\n",
    "用数学语言描述，假设有一个查询 $\\mathbf{q} \\in \\mathbb{R}^q$ 和 $m$ 个键值对 $(\\mathbf{k}_1, \\mathbf{v}_1), \\ldots, (\\mathbf{k}_m, \\mathbf{v}_m)$，其中 $\\mathbf{k}_i \\in \\mathbb{R}^k$，$\\mathbf{v}_i \\in \\mathbb{R}^v$。注意力汇聚函数  𝑓  就被表示成值的加权和：\n",
    "\n",
    "$$f(\\mathbf{q}, (\\mathbf{k}_1, \\mathbf{v}_1), \\ldots, (\\mathbf{k}_m, \\mathbf{v}_m)) = \\sum_{i=1}^m \\alpha(\\mathbf{q}, \\mathbf{k}_i) \\mathbf{v}_i \\in \\mathbb{R}^v,$$\n",
    ":eqlabel:`eq_attn-pooling`\n",
    "\n",
    "其中查询 $\\mathbf{q}$ 和键 $\\mathbf{k}_i$ 的注意力权重（标量）是通过注意力评分函数 $a$ 将两个向量映射成标量，再经过 softmax 运算得到的：\n",
    "\n",
    "$$\\alpha(\\mathbf{q}, \\mathbf{k}_i) = \\mathrm{softmax}(a(\\mathbf{q}, \\mathbf{k}_i)) = \\frac{\\exp(a(\\mathbf{q}, \\mathbf{k}_i))}{\\sum_{j=1}^m \\exp(a(\\mathbf{q}, \\mathbf{k}_j))} \\in \\mathbb{R}.$$\n",
    ":eqlabel:`eq_attn-scoring-alpha`\n",
    "\n",
    "正如我们所看到的，选择不同的注意力评分函数 $a$ 会导致不同的注意力汇聚操作。在本节中，我们将介绍两个流行的评分函数，稍后将用他们来实现更复杂的注意力机制。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/plot-utils\n",
    "%load ../utils/Functions.java\n",
    "%load ../utils/PlotUtils.java"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDManager manager = NDManager.newBaseManager();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 3
   },
   "source": [
    "## 遮蔽softmax操作\n",
    "\n",
    "正如上面提到的，softmax 运算用于输出一个概率分布作为注意力权重。在某些情况下，并非所有的值都应该被纳入到注意力汇聚中。例如，为了在 :numref:`sec_machine_translation` 中高效处理小批量数据集，某些文本序列被填充了没有意义的特殊词元。为了仅将有意义的词元作为值去获取注意力汇聚，可以指定一个有效序列长度（即词元的个数），以便在计算 softmax 时过滤掉超出指定范围的位置。通过这种方式，我们可以在下面的 `maskedSoftmax` 函数中实现这样的 *遮蔽 softmax 操作（masked softmax operation）*，其中任何超出有效长度的位置都被遮蔽并置为0。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static NDArray maskedSoftmax(NDArray X, NDArray validLens) {\n",
    "    /* Perform softmax operation by masking elements on the last axis. */\n",
    "    // `X`: 3D NDArray, `validLens`: 1D or 2D NDArray\n",
    "    if (validLens == null) {\n",
    "        return X.softmax(-1);\n",
    "    }\n",
    "    \n",
    "    Shape shape = X.getShape();\n",
    "    if (validLens.getShape().dimension() == 1) {\n",
    "        validLens = validLens.repeat(shape.get(1));\n",
    "    } else {\n",
    "        validLens = validLens.reshape(-1);\n",
    "    }\n",
    "    // On the last axis, replace masked elements with a very large negative\n",
    "    // value, whose exponentiation outputs 0\n",
    "    X = X.reshape(new Shape(-1, shape.get(shape.dimension() - 1)))\n",
    "            .sequenceMask(validLens, (float) -1E6);\n",
    "    return X.softmax(-1).reshape(shape);\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 6
   },
   "source": [
    "为了演示此函数是如何工作的，考虑由两个 $2 \\times 4$ 矩阵表示的样本，这两个样本的有效长度分别为2和3。经过遮蔽 softmax 操作，超出有效长度的值都被遮蔽为0。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "maskedSoftmax(\n",
    "        manager.randomUniform(0, 1, new Shape(2, 2, 4)),\n",
    "        manager.create(new float[] {2, 3}));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 9
   },
   "source": [
    "同样，我们也可以使用二维张量为矩阵样本中的每一行指定有效长度。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "maskedSoftmax(\n",
    "        manager.randomUniform(0, 1, new Shape(2, 2, 4)),\n",
    "        manager.create(new float[][] {{1, 3}, {2, 4}}));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 12
   },
   "source": [
    "## 加性注意力\n",
    ":label:`subsec_additive-attention`\n",
    "\n",
    "\n",
    "一般来说，当查询和键是不同长度的矢量时，可以使用加性注意力作为评分函数。给定查询 $\\mathbf{q} \\in \\mathbb{R}^q$\n",
    "和键 $\\mathbf{k} \\in \\mathbb{R}^k$，*加性注意力（additive attention）* 的评分函数为\n",
    "\n",
    "$$a(\\mathbf q, \\mathbf k) = \\mathbf w_v^\\top \\text{tanh}(\\mathbf W_q\\mathbf q + \\mathbf W_k \\mathbf k) \\in \\mathbb{R},$$\n",
    ":eqlabel:`eq_additive-attn`\n",
    "\n",
    "其中可学习的参数是 $\\mathbf W_q\\in\\mathbb R^{h\\times q}$, $\\mathbf W_k\\in\\mathbb R^{h\\times k}$, and $\\mathbf w_v\\in\\mathbb R^{h}$ 。如 :eqref:`eq_additive-attn` 所示，将查询和键连接起来后输入到一个多层感知机（MLP）中，感知机包含一个隐藏层，其隐藏单元数是一个超参数 $h$。通过使用 $\\tanh$ 作为激活函数，并且禁用偏置项，我们将在下面实现加性注意力。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/* Additive attention. */\n",
    "public static class AdditiveAttention extends AbstractBlock {\n",
    "\n",
    "    private Linear W_k;\n",
    "    private Linear W_q;\n",
    "    private Linear W_v;\n",
    "    private Dropout dropout;\n",
    "    public NDArray attentionWeights;\n",
    "\n",
    "    public AdditiveAttention(int numHiddens, float dropout) {\n",
    "        W_k = Linear.builder().setUnits(numHiddens).optBias(false).build();\n",
    "        addChildBlock(\"W_k\", W_k);\n",
    "\n",
    "        W_q = Linear.builder().setUnits(numHiddens).optBias(false).build();\n",
    "        addChildBlock(\"W_q\", W_q);\n",
    "\n",
    "        W_v = Linear.builder().setUnits(1).optBias(false).build();\n",
    "        addChildBlock(\"W_v\", W_v);\n",
    "\n",
    "        this.dropout = Dropout.builder().optRate(dropout).build();\n",
    "        addChildBlock(\"dropout\", this.dropout);\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    protected NDList forwardInternal(\n",
    "            ParameterStore ps,\n",
    "            NDList inputs,\n",
    "            boolean training,\n",
    "            PairList<String, Object> params) {\n",
    "        // Shape of the output `queries` and `attentionWeights`:\n",
    "        // (no. of queries, no. of key-value pairs)\n",
    "        NDArray queries = inputs.get(0);\n",
    "        NDArray keys = inputs.get(1);\n",
    "        NDArray values = inputs.get(2);\n",
    "        NDArray validLens = inputs.get(3);\n",
    "\n",
    "        queries = W_q.forward(ps, new NDList(queries), training, params).head();\n",
    "        keys = W_k.forward(ps, new NDList(keys), training, params).head();\n",
    "        // After dimension expansion, shape of `queries`: (`batchSize`, no. of\n",
    "        // queries, 1, `numHiddens`) and shape of `keys`: (`batchSize`, 1,\n",
    "        // no. of key-value pairs, `numHiddens`). Sum them up with\n",
    "        // broadcasting\n",
    "        NDArray features = queries.expandDims(2).add(keys.expandDims(1));\n",
    "        features = features.tanh();\n",
    "        // There is only one output of `this.W_v`, so we remove the last\n",
    "        // one-dimensional entry from the shape. Shape of `scores`:\n",
    "        // (`batchSize`, no. of queries, no. of key-value pairs)\n",
    "        NDArray result = W_v.forward(ps, new NDList(features), training, params).head();\n",
    "        NDArray scores = result.squeeze(-1);\n",
    "        attentionWeights = maskedSoftmax(scores, validLens);\n",
    "        // Shape of `values`: (`batchSize`, no. of key-value pairs, value dimension)\n",
    "        NDList list = dropout.forward(ps, new NDList(attentionWeights), training, params);\n",
    "        return new NDList(list.head().batchDot(values));\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    public Shape[] getOutputShapes(Shape[] inputShapes) {\n",
    "        throw new UnsupportedOperationException(\"Not implemented\");\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    public void initializeChildBlocks(\n",
    "            NDManager manager, DataType dataType, Shape... inputShapes) {\n",
    "        W_q.initialize(manager, dataType, inputShapes[0]);\n",
    "        W_k.initialize(manager, dataType, inputShapes[1]);\n",
    "        long[] q = W_q.getOutputShapes(new Shape[] {inputShapes[0]})[0].getShape();\n",
    "        long[] k = W_k.getOutputShapes(new Shape[] {inputShapes[1]})[0].getShape();\n",
    "        long w = Math.max(q[q.length - 2], k[k.length - 2]);\n",
    "        long h = Math.max(q[q.length - 1], k[k.length - 1]);\n",
    "        long[] shape = new long[] {2, 1, w, h};\n",
    "        W_v.initialize(manager, dataType, new Shape(shape));\n",
    "        long[] dropoutShape = new long[shape.length - 1];\n",
    "        System.arraycopy(shape, 0, dropoutShape, 0, dropoutShape.length);\n",
    "        dropout.initialize(manager, dataType, new Shape(dropoutShape));\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 15
   },
   "source": [
    "让我们用一个小例子来演示上面的 `AdditiveAttention` 类，其中查询、键和值的形状为（批量大小、步数或词元序列长度、特征大小），实际输出为  ($2$, $1$, $20$), ($2$, $10$, $2$) 和 ($2$, $10$, $4$)。注意力汇聚输出的形状为（批量大小、查询的步数、值的维度）。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDArray queries = manager.randomNormal(0, 1, new Shape(2, 1, 20), DataType.FLOAT32);\n",
    "NDArray keys = manager.ones(new Shape(2, 10, 2));\n",
    "// The two value matrices in the `values` minibatch are identical\n",
    "NDArray values = manager.arange(40f).reshape(1, 10, 4).repeat(0, 2);\n",
    "NDArray validLens = manager.create(new float[] {2, 6});\n",
    "\n",
    "AdditiveAttention attention = new AdditiveAttention(8, 0.1f);\n",
    "NDList input = new NDList(queries, keys, values, validLens);\n",
    "ParameterStore ps = new ParameterStore(manager, false);\n",
    "attention.initialize(manager, DataType.FLOAT32, input.getShapes());\n",
    "attention.forward(ps, input, false).head();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 18
   },
   "source": [
    "尽管加性注意力包含了可学习的参数，但由于本例子中每个键都是相同的，所以注意力权重是均匀的，由指定的有效长度决定。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "PlotUtils.showHeatmaps(\n",
    "            attention.attentionWeights.reshape(1, 1, 2, 10),\n",
    "            \"Keys\",\n",
    "            \"Queries\",\n",
    "            new String[] {\"\"},\n",
    "            500,\n",
    "            700);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 20
   },
   "source": [
    "## 缩放点积注意力\n",
    "\n",
    "使用点积可以得到计算效率更高的评分函数。但是点积操作要求查询和键具有相同的长度 $d$。假设查询和键的所有元素都是独立的随机变量，并且都满足均值为0和方差为。那么两个向量的点积的均值为0，方差为 $d$。为确保无论向量长度如何，点积的方差在不考虑向量长度的情况下仍然是1，则可以使用 *缩放点积注意力（scaled dot-product attention）* 评分函数：\n",
    "\n",
    "$$a(\\mathbf q, \\mathbf k) = \\mathbf{q}^\\top \\mathbf{k}  /\\sqrt{d}$$\n",
    "\n",
    "将点积除以 $\\sqrt{d}$。在实践中，我们通常从小批量的角度来考虑提高效率，例如基于 $n$ 个查询和 $m$ 个键－值对计算注意力，其中查询和键的长度为 $d$，值的长度为 $v$。查询 $\\mathbf Q\\in\\mathbb R^{n\\times d}$，键 $\\mathbf K\\in\\mathbb R^{m\\times d}$ 和值 $\\mathbf V\\in\\mathbb R^{m\\times v}$ 的缩放点积注意力是\n",
    "\n",
    "$$ \\mathrm{softmax}\\left(\\frac{\\mathbf Q \\mathbf K^\\top }{\\sqrt{d}}\\right) \\mathbf V \\in \\mathbb{R}^{n\\times v}.$$\n",
    ":eqlabel:`eq_softmax_QK_V`\n",
    "\n",
    "在下面的缩放点积注意力的实现中，我们使用了 dropout 进行模型正则化。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/* Scaled dot product attention. */\n",
    "public static class DotProductAttention extends AbstractBlock {\n",
    "\n",
    "    private Dropout dropout;\n",
    "    public NDArray attentionWeights;\n",
    "\n",
    "    public DotProductAttention(float dropout) {\n",
    "        this.dropout = Dropout.builder().optRate(dropout).build();\n",
    "        this.addChildBlock(\"dropout\", this.dropout);\n",
    "        this.dropout.setInitializer(new UniformInitializer(0.07f), Parameter.Type.WEIGHT);\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    protected NDList forwardInternal(\n",
    "            ParameterStore ps,\n",
    "            NDList inputs,\n",
    "            boolean training,\n",
    "            PairList<String, Object> params) {\n",
    "        // Shape of `queries`: (`batchSize`, no. of queries, `d`)\n",
    "        // Shape of `keys`: (`batchSize`, no. of key-value pairs, `d`)\n",
    "        // Shape of `values`: (`batchSize`, no. of key-value pairs, value\n",
    "        // dimension)\n",
    "        // Shape of `valid_lens`: (`batchSize`,) or (`batchSize`, no. of queries)\n",
    "        NDArray queries = inputs.get(0);\n",
    "        NDArray keys = inputs.get(1);\n",
    "        NDArray values = inputs.get(2);\n",
    "        NDArray validLens = inputs.get(3);\n",
    "\n",
    "        Long d = queries.getShape().get(queries.getShape().dimension() - 1);\n",
    "        // Swap the last two dimensions of `keys` and perform batchDot\n",
    "        NDArray scores = queries.batchDot(keys.swapAxes(1, 2)).div(Math.sqrt(2));\n",
    "        attentionWeights = maskedSoftmax(scores, validLens);\n",
    "        NDList list = dropout.forward(ps, new NDList(attentionWeights), training, params);\n",
    "        return new NDList(list.head().batchDot(values));\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    public Shape[] getOutputShapes(Shape[] inputShapes) {\n",
    "        throw new UnsupportedOperationException(\"Not implemented\");\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    public void initializeChildBlocks(\n",
    "            NDManager manager, DataType dataType, Shape... inputShapes) {\n",
    "        try (NDManager sub = manager.newSubManager()) {\n",
    "            NDArray queries = sub.zeros(inputShapes[0], dataType);\n",
    "            NDArray keys = sub.zeros(inputShapes[1], dataType);\n",
    "            NDArray scores = queries.batchDot(keys.swapAxes(1, 2));\n",
    "            dropout.initialize(manager, dataType, scores.getShape());\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 23
   },
   "source": [
    "为了演示上述的 `DotProductAttention` 类，我们使用了与先前加性注意力例子中相同的键、值和有效长度。对于点积操作，令查询的特征维度与键的特征维度大小相同。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "queries = manager.randomNormal(0, 1, new Shape(2, 1, 2), DataType.FLOAT32);\n",
    "DotProductAttention productAttention = new DotProductAttention(0.5f);\n",
    "input = new NDList(queries, keys, values, validLens);\n",
    "productAttention.initialize(manager, DataType.FLOAT32, input.getShapes());\n",
    "productAttention.forward(ps, input, false).head();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 26
   },
   "source": [
    "与加性注意力演示相同，由于`键` 包含的是相同的元素，而这些元素无法通过任何查询进行区分，因此获得了均匀的注意力权重。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "PlotUtils.showHeatmaps(\n",
    "        productAttention.attentionWeights.reshape(1, 1, 2, 10),\n",
    "        \"Keys\",\n",
    "        \"Queries\",\n",
    "        new String[] {\"\"},\n",
    "        500,\n",
    "        700);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 28
   },
   "source": [
    "## 小结\n",
    "\n",
    "* 可以将注意力汇聚的输出计算作为值的加权平均，选择不同的注意力评分函数会带来不同的注意力汇聚操作。\n",
    "* 当查询和键是不同长度的矢量时，可以使用可加性注意力评分函数。当它们的长度相同时，使用缩放的“点－积”注意力评分函数的计算效率更高。\n",
    "\n",
    "\n",
    "## 练习\n",
    "\n",
    "1. 修改小例子中的键，并且可视化注意力权重。可加性注意力和缩放的“点－积”注意力是否仍然产生相同的结果？为什么？\n",
    "1. 只使用矩阵乘法，您能否为具有不同矢量长度的查询和键设计新的评分函数？\n",
    "1. 当查询和键具有相同的矢量长度时，矢量求和作为评分函数是否比“点－积”更好？为什么？\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
